1   package org.apache.lucene.search.join;
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  import java.io.IOException;
21  import java.util.Arrays;
22  import java.util.HashMap;
23  import java.util.LinkedList;
24  import java.util.Map;
25  import java.util.Queue;
26  
27  import org.apache.lucene.index.IndexWriter;
28  import org.apache.lucene.index.LeafReaderContext;
29  import org.apache.lucene.search.Collector;
30  import org.apache.lucene.search.FieldComparator;
31  import org.apache.lucene.search.FieldValueHitQueue;
32  import org.apache.lucene.search.LeafCollector;
33  import org.apache.lucene.search.LeafFieldComparator;
34  import org.apache.lucene.search.Query;
35  import org.apache.lucene.search.ScoreCachingWrappingScorer;
36  import org.apache.lucene.search.Scorer;
37  import org.apache.lucene.search.Scorer.ChildScorer;
38  import org.apache.lucene.search.Sort;
39  import org.apache.lucene.search.TopDocs;
40  import org.apache.lucene.search.TopDocsCollector;
41  import org.apache.lucene.search.TopFieldCollector;
42  import org.apache.lucene.search.TopScoreDocCollector;
43  import org.apache.lucene.search.grouping.GroupDocs;
44  import org.apache.lucene.search.grouping.TopGroups;
45  import org.apache.lucene.util.ArrayUtil;
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
/**
 * Collector that keeps the top {@code numParentHits} parent documents, ordered by a
 * {@link Sort}, and for each retained parent records the child documents (and optionally
 * their scores) reported by every {@link ToParentBlockJoinQuery} scorer found in the
 * scorer tree. The recorded children are later turned into {@link TopGroups} by
 * {@link #getTopGroups}.
 */
public class ToParentBlockJoinCollector implements Collector {

  // Sort used to order the parent hits.
  private final Sort sort;

  // Maps each enrolled join query to its index in joinScorers and in the
  // per-group docs/scores/counts arrays.
  private final Map<Query,Integer> joinQueryID = new HashMap<>();
  // Maximum number of parent hits retained in the queue.
  private final int numParentHits;
  // Priority queue holding the retained parent hits.
  private final FieldValueHitQueue<OneGroup> queue;
  // Top-level comparators of the queue; used to read back sort values per slot.
  private final FieldComparator<?>[] comparators;
  // When true, every collected parent is scored and maxScore is updated.
  private final boolean trackMaxScore;
  // When true, parent scores and per-child scores are recorded.
  private final boolean trackScores;

  // Join scorers of the current segment, indexed by join slot; entries are reset to
  // null on each setScorer call and re-populated by enroll().
  private ToParentBlockJoinQuery.BlockJoinScorer[] joinScorers = new ToParentBlockJoinQuery.BlockJoinScorer[0];
  // Becomes true once numParentHits parents have been collected.
  private boolean queueFull;

  // Entry returned by the queue after the last add/updateTop; its slot is the one
  // handed to the comparators via setBottom and overwritten by a competitive hit.
  private OneGroup bottom;
  // Number of parent documents passed to collect(), whether retained or not.
  private int totalHitCount;
  // Highest parent score seen; stays NaN unless trackMaxScore is set.
  private float maxScore = Float.NaN;
115 
116   
117 
118 
119 
120   public ToParentBlockJoinCollector(Sort sort, int numParentHits, boolean trackScores, boolean trackMaxScore) throws IOException {
121     
122     
123     this.sort = sort;
124     this.trackMaxScore = trackMaxScore;
125     if (trackMaxScore) {
126       maxScore = Float.MIN_VALUE;
127     }
128     
129     this.trackScores = trackScores;
130     this.numParentHits = numParentHits;
131     queue = FieldValueHitQueue.create(sort.getSort(), numParentHits);
132     comparators = queue.getComparators();
133   }
134   
135   private static final class OneGroup extends FieldValueHitQueue.Entry {
136     public OneGroup(int comparatorSlot, int parentDoc, float parentScore, int numJoins, boolean doScores) {
137       super(comparatorSlot, parentDoc, parentScore);
138       
139       docs = new int[numJoins][];
140       for(int joinID=0;joinID<numJoins;joinID++) {
141         docs[joinID] = new int[5];
142       }
143       if (doScores) {
144         scores = new float[numJoins][];
145         for(int joinID=0;joinID<numJoins;joinID++) {
146           scores[joinID] = new float[5];
147         }
148       }
149       counts = new int[numJoins];
150     }
151     LeafReaderContext readerContext;
152     int[][] docs;
153     float[][] scores;
154     int[] counts;
155   }
156 
157   @Override
158   public LeafCollector getLeafCollector(final LeafReaderContext context)
159       throws IOException {
160     final LeafFieldComparator[] comparators = queue.getComparators(context);
161     final int[] reverseMul = queue.getReverseMul();
162     final int docBase = context.docBase;
163     return new LeafCollector() {
164 
165       private Scorer scorer;
166 
167       @Override
168       public void setScorer(Scorer scorer) throws IOException {
169         
170         
171         
172         
173         if (scorer instanceof ScoreCachingWrappingScorer == false) {
174           scorer = new ScoreCachingWrappingScorer(scorer);
175         }
176         this.scorer = scorer;
177         for (LeafFieldComparator comparator : comparators) {
178           comparator.setScorer(scorer);
179         }
180         Arrays.fill(joinScorers, null);
181 
182         Queue<Scorer> queue = new LinkedList<>();
183         
184         queue.add(scorer);
185         while ((scorer = queue.poll()) != null) {
186           
187           if (scorer instanceof ToParentBlockJoinQuery.BlockJoinScorer) {
188             enroll((ToParentBlockJoinQuery) scorer.getWeight().getQuery(), (ToParentBlockJoinQuery.BlockJoinScorer) scorer);
189           }
190 
191           for (ChildScorer sub : scorer.getChildren()) {
192             
193             queue.add(sub.child);
194           }
195         }
196       }
197       
198       @Override
199       public void collect(int parentDoc) throws IOException {
200       
201         totalHitCount++;
202 
203         float score = Float.NaN;
204 
205         if (trackMaxScore) {
206           score = scorer.score();
207           maxScore = Math.max(maxScore, score);
208         }
209 
210         
211         
212         
213 
214         if (queueFull) {
215           
216           
217           int c = 0;
218           for (int i = 0; i < comparators.length; ++i) {
219             c = reverseMul[i] * comparators[i].compareBottom(parentDoc);
220             if (c != 0) {
221               break;
222             }
223           }
224           if (c <= 0) { 
225             
226             
227             return;
228           }
229 
230           
231 
232           
233           for (LeafFieldComparator comparator : comparators) {
234             comparator.copy(bottom.slot, parentDoc);
235           }
236           if (!trackMaxScore && trackScores) {
237             score = scorer.score();
238           }
239           bottom.doc = docBase + parentDoc;
240           bottom.readerContext = context;
241           bottom.score = score;
242           copyGroups(bottom);
243           bottom = queue.updateTop();
244 
245           for (LeafFieldComparator comparator : comparators) {
246             comparator.setBottom(bottom.slot);
247           }
248         } else {
249           
250           final int comparatorSlot = totalHitCount - 1;
251 
252           
253           for (LeafFieldComparator comparator : comparators) {
254             comparator.copy(comparatorSlot, parentDoc);
255           }
256           
257           if (!trackMaxScore && trackScores) {
258             score = scorer.score();
259           }
260           final OneGroup og = new OneGroup(comparatorSlot, docBase+parentDoc, score, joinScorers.length, trackScores);
261           og.readerContext = context;
262           copyGroups(og);
263           bottom = queue.add(og);
264           queueFull = totalHitCount == numParentHits;
265           if (queueFull) {
266             
267             for (LeafFieldComparator comparator : comparators) {
268               comparator.setBottom(bottom.slot);
269             }
270           }
271         }
272       }
273       
274       
275       private void copyGroups(OneGroup og) {
276         
277         
278         
279         final int numSubScorers = joinScorers.length;
280         if (og.docs.length < numSubScorers) {
281           
282           
283           
284           og.docs = ArrayUtil.grow(og.docs);
285         }
286         if (og.counts.length < numSubScorers) {
287           og.counts = ArrayUtil.grow(og.counts);
288         }
289         if (trackScores && og.scores.length < numSubScorers) {
290           og.scores = ArrayUtil.grow(og.scores);
291         }
292 
293         
294         for(int scorerIDX = 0;scorerIDX < numSubScorers;scorerIDX++) {
295           final ToParentBlockJoinQuery.BlockJoinScorer joinScorer = joinScorers[scorerIDX];
296           
297           if (joinScorer != null && docBase + joinScorer.getParentDoc() == og.doc) {
298             og.counts[scorerIDX] = joinScorer.getChildCount();
299             
300             og.docs[scorerIDX] = joinScorer.swapChildDocs(og.docs[scorerIDX]);
301             assert og.docs[scorerIDX].length >= og.counts[scorerIDX]: "length=" + og.docs[scorerIDX].length + " vs count=" + og.counts[scorerIDX];
302             
303             
304 
305 
306 
307 
308             if (trackScores) {
309               
310               og.scores[scorerIDX] = joinScorer.swapChildScores(og.scores[scorerIDX]);
311               assert og.scores[scorerIDX].length >= og.counts[scorerIDX]: "length=" + og.scores[scorerIDX].length + " vs count=" + og.counts[scorerIDX];
312             }
313           } else {
314             og.counts[scorerIDX] = 0;
315           }
316         }
317       }
318     };
319   }
320 
321   private void enroll(ToParentBlockJoinQuery query, ToParentBlockJoinQuery.BlockJoinScorer scorer) {
322     scorer.trackPendingChildHits();
323     final Integer slot = joinQueryID.get(query);
324     if (slot == null) {
325       joinQueryID.put(query, joinScorers.length);
326       
327       final ToParentBlockJoinQuery.BlockJoinScorer[] newArray = new ToParentBlockJoinQuery.BlockJoinScorer[1+joinScorers.length];
328       System.arraycopy(joinScorers, 0, newArray, 0, joinScorers.length);
329       joinScorers = newArray;
330       joinScorers[joinScorers.length-1] = scorer;
331     } else {
332       joinScorers[slot] = scorer;
333     }
334   }
335 
336   private OneGroup[] sortedGroups;
337 
338   private void sortQueue() {
339     sortedGroups = new OneGroup[queue.size()];
340     for(int downTo=queue.size()-1;downTo>=0;downTo--) {
341       sortedGroups[downTo] = queue.pop();
342     }
343   }
344 
345   
346 
347 
348 
349 
350 
351 
352 
353 
354 
355 
356 
357 
358 
359 
360 
361   public TopGroups<Integer> getTopGroups(ToParentBlockJoinQuery query, Sort withinGroupSort, int offset,
362                                          int maxDocsPerGroup, int withinGroupOffset, boolean fillSortFields)
363     throws IOException {
364 
365     final Integer _slot = joinQueryID.get(query);
366     if (_slot == null && totalHitCount == 0) {
367       return null;
368     }
369 
370     if (sortedGroups == null) {
371       if (offset >= queue.size()) {
372         return null;
373       }
374       sortQueue();
375     } else if (offset > sortedGroups.length) {
376       return null;
377     }
378 
379     return accumulateGroups(_slot == null ? -1 : _slot.intValue(), offset, maxDocsPerGroup, withinGroupOffset, withinGroupSort, fillSortFields);
380   }
381 
382   
383 
384 
385 
386 
387 
388 
389 
390 
391 
392 
393 
394   @SuppressWarnings({"unchecked","rawtypes"})
395   private TopGroups<Integer> accumulateGroups(int slot, int offset, int maxDocsPerGroup,
396                                               int withinGroupOffset, Sort withinGroupSort, boolean fillSortFields) throws IOException {
397     final GroupDocs<Integer>[] groups = new GroupDocs[sortedGroups.length - offset];
398     final FakeScorer fakeScorer = new FakeScorer();
399 
400     int totalGroupedHitCount = 0;
401     
402 
403     for(int groupIDX=offset;groupIDX<sortedGroups.length;groupIDX++) {
404       final OneGroup og = sortedGroups[groupIDX];
405       final int numChildDocs;
406       if (slot == -1 || slot >= og.counts.length) {
407         numChildDocs = 0;
408       } else {
409         numChildDocs = og.counts[slot];
410       }
411 
412       
413       final int numDocsInGroup = Math.max(1, Math.min(numChildDocs, maxDocsPerGroup));
414       
415 
416       
417       
418       final TopDocsCollector<?> collector;
419       if (withinGroupSort == null) {
420         
421         
422         if (!trackScores) {
423           throw new IllegalArgumentException("cannot sort by relevance within group: trackScores=false");
424         }
425         collector = TopScoreDocCollector.create(numDocsInGroup);
426       } else {
427         
428         collector = TopFieldCollector.create(withinGroupSort, numDocsInGroup, fillSortFields, trackScores, trackMaxScore);
429       }
430 
431       LeafCollector leafCollector = collector.getLeafCollector(og.readerContext);
432       leafCollector.setScorer(fakeScorer);
433       for(int docIDX=0;docIDX<numChildDocs;docIDX++) {
434         
435         final int doc = og.docs[slot][docIDX];
436         fakeScorer.doc = doc;
437         if (trackScores) {
438           fakeScorer.score = og.scores[slot][docIDX];
439         }
440         leafCollector.collect(doc);
441       }
442       totalGroupedHitCount += numChildDocs;
443 
444       final Object[] groupSortValues;
445 
446       if (fillSortFields) {
447         groupSortValues = new Object[comparators.length];
448         for(int sortFieldIDX=0;sortFieldIDX<comparators.length;sortFieldIDX++) {
449           groupSortValues[sortFieldIDX] = comparators[sortFieldIDX].value(og.slot);
450         }
451       } else {
452         groupSortValues = null;
453       }
454 
455       final TopDocs topDocs = collector.topDocs(withinGroupOffset, numDocsInGroup);
456 
457       groups[groupIDX-offset] = new GroupDocs<>(og.score,
458                                                        topDocs.getMaxScore(),
459                                                        numChildDocs,
460                                                        topDocs.scoreDocs,
461                                                        og.doc,
462                                                        groupSortValues);
463     }
464 
465     return new TopGroups<>(new TopGroups<>(sort.getSort(),
466                                                        withinGroupSort == null ? null : withinGroupSort.getSort(),
467                                                        0, totalGroupedHitCount, groups, maxScore),
468                                   totalHitCount);
469   }
470 
471   
472 
473 
474 
475 
476 
477 
478 
479 
480 
481 
482 
483 
484 
485   public TopGroups<Integer> getTopGroupsWithAllChildDocs(ToParentBlockJoinQuery query, Sort withinGroupSort, int offset,
486                                                          int withinGroupOffset, boolean fillSortFields)
487     throws IOException {
488 
489     return getTopGroups(query, withinGroupSort, offset, Integer.MAX_VALUE, withinGroupOffset, fillSortFields);
490   }
491   
492   
493 
494 
495 
496 
497 
498   public float getMaxScore() {
499     return maxScore;
500   }
501 
  /**
   * Always reports that scores are needed, regardless of the {@code trackScores} and
   * {@code trackMaxScore} settings.
   *
   * @return {@code true}
   */
  @Override
  public boolean needsScores() {
    // Unconditional on purpose in this code: no flag is consulted here.
    // NOTE(review): presumably conservative because the join scorers being tracked
    // may need scoring even when parent scores are not recorded - confirm.
    return true;
  }
508 }